Personality Prediction with LIME

1 Model Explainability with LIME

In this notebook, I use LIME (Local Interpretable Model-Agnostic Explanations) to interpret the predictions of my PyTorch neural network, which was trained on a tabular dataset to classify individuals as Introverts or Extroverts. The model was built for the Kaggle Playground Series S5E7 competition.

The LIME source code is available on GitHub (marcotcr/lime).

1.1 What is LIME?

Machine learning models, especially neural networks, often act as black boxes: they can achieve strong predictive performance, but it is difficult to understand why a specific prediction was made.

LIME addresses this challenge by fitting local surrogate models. The key idea:

  1. Take a single prediction we want to explain.
  2. Generate many small perturbations of the input data point.
  3. Run these perturbed samples through the original black-box model.
  4. Fit a simple, interpretable model (such as a weighted linear regression) to approximate the black-box model’s behavior locally around that input.
  5. Use the weights of this surrogate model to highlight which features were most influential in the prediction.

Because LIME focuses only on the local neighborhood of a given observation, the explanation is tailored to that specific prediction rather than the model as a whole.
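The five steps above can be sketched with NumPy alone. This is a toy illustration, not LIME's actual implementation: the black-box function, kernel width, and perturbation scale below are all illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy black box: a nonlinear "model" returning P(class 1).
def black_box(X):
    return 1 / (1 + np.exp(-(np.sin(3 * X[:, 0]) + X[:, 1] ** 2)))

x0 = np.array([0.2, -0.5])                       # the instance to explain

# Steps 1-3: perturb around x0 and query the black box.
Z = x0 + rng.normal(scale=0.3, size=(500, 2))
y = black_box(Z)

# Step 4: weight perturbations by proximity (RBF kernel) and fit a
# weighted linear surrogate: scale rows by sqrt(weight), then least squares.
d2 = ((Z - x0) ** 2).sum(axis=1)
w = np.exp(-d2 / 0.25)
A = np.hstack([Z - x0, np.ones((len(Z), 1))])    # local coords + intercept
sw = np.sqrt(w)[:, None]
coef, *_ = np.linalg.lstsq(A * sw, y * sw.ravel(), rcond=None)

# Step 5: the surrogate's weights indicate local feature influence.
print("local weights:", coef[:2])
```

Here the first weight comes out positive (the toy function rises with feature 0 near x0) and the second negative, mirroring how LIME's condition weights are read.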


1.2 Why is LIME Important?

  • Transparency: LIME makes model decisions understandable to humans, even when the underlying model is complex (e.g., deep learning).
  • Trust: Stakeholders (such as employers, clinicians, or end-users) are more likely to trust a model if its predictions can be explained.
  • Debugging: By showing which features drive predictions, LIME helps identify if the model relies on irrelevant or biased inputs.
  • Model-Agnostic: LIME works with any classifier or regressor (PyTorch, TensorFlow, scikit-learn, etc.) as long as the model provides predictions.

For this project, applying LIME allows me to highlight which psychological traits most strongly influenced the prediction of introversion vs. extroversion for each individual.


1.3 How LIME Fits Into the Workflow

  1. Train a PyTorch model on the dataset (done earlier).
  2. Wrap the model’s predict function so it accepts a NumPy array and outputs class probabilities.
  3. Initialize LimeTabularExplainer with the training data, feature names, and class labels.
  4. Explain individual predictions by calling explain_instance on specific test samples.
  5. Visualize the results in notebook tables, plots, or HTML files.

1.4 Code Overview

Code
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from lime.lime_tabular import LimeTabularExplainer
import pandas as pd
Code
class PersonalityModel(nn.Module):
    def __init__(self, input_features, output_features, hidden_features = 32):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features = input_features, out_features = hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.GELU(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features = hidden_features, out_features = hidden_features),
            nn.BatchNorm1d(hidden_features),
            nn.GELU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features = hidden_features, out_features = hidden_features),
            nn.SiLU(),
            nn.Dropout(p=0.4),
            nn.Linear(in_features = hidden_features, out_features = output_features)
        )

    def forward(self, x):
        return self.layers(x)
Code
# 0. Retrieve model weights and load data
model = PersonalityModel(input_features = 20, output_features = 1, hidden_features = 512)
model.load_state_dict(torch.load("model_weights.pth"))

train_df = pd.read_csv('X_train_processed.csv')
X_train = train_df.to_numpy(dtype=np.float32)
X_test = pd.read_csv('X_test_processed.csv').to_numpy(dtype=np.float32)
feature_names = train_df.columns.tolist()   # LIME expects a list of names

# 1. Wrap the PyTorch model for LIME
model.eval()
def predict_fn(x: np.ndarray):
    with torch.no_grad():
        inputs = torch.tensor(x, dtype=torch.float32)
        logits = model(inputs).squeeze(dim=1)  # shape (n_samples,); dim=1 avoids collapsing a 1-sample batch to a scalar
        prob_pos = torch.sigmoid(logits)   # probability of class 1
        prob_neg = 1 - prob_pos            # probability of class 0
        probs = torch.stack([prob_neg, prob_pos], dim=1)
    return probs.numpy()                   # return as numpy array

# 2. Initialize the explainer
explainer = LimeTabularExplainer(
    training_data=X_train,
    feature_names=feature_names,
    class_names=["Introvert", "Extrovert"],
    mode="classification"
)

# 3. Explain a single prediction
i = 0  # index of sample to explain
exp = explainer.explain_instance(
    data_row=X_test[i],
    predict_fn=predict_fn,
    num_features=10
)

# 4. Display results
exp.show_in_notebook(show_table=True)
print(exp.as_list())
[('match_p <= -0.10', 0.23165595245151016), ('alone_flag <= -0.25', 0.11018292086153071), ('Social_event_attendance_isna <= -0.26', 0.10042801373324538), ('Drained_after_socializing <= 0.26', 0.05818020989834077), ('Post_frequency_isna <= -0.27', 0.0497936014927899), ('Stage_fear <= 0.34', 0.04408518791699877), ('Time_spent_Alone <= 0.25', 0.04076641769609262), ('introversion_score <= 0.25', 0.03849683230630708), ('Time_spent_Alone_isna > -0.26', -0.02653337407055524), ('Friends_circle_size_isna <= -0.25', -0.026286142400476813)]
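The raw exp.as_list() output can also be post-processed with pandas, for example into a table sorted by absolute weight. A minimal sketch, using the first five (condition, weight) pairs copied from the output above:

```python
import pandas as pd

# First five (condition, weight) pairs from exp.as_list() above.
pairs = [
    ('match_p <= -0.10', 0.23165595245151016),
    ('alone_flag <= -0.25', 0.11018292086153071),
    ('Social_event_attendance_isna <= -0.26', 0.10042801373324538),
    ('Drained_after_socializing <= 0.26', 0.05818020989834077),
    ('Post_frequency_isna <= -0.27', 0.0497936014927899),
]

df = pd.DataFrame(pairs, columns=['condition', 'weight'])
# Order by absolute weight so negative contributions rank fairly too.
df = df.reindex(df['weight'].abs().sort_values(ascending=False).index)
print(df.to_string(index=False))
```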

1.5 Explaining Low-Confidence Predictions with LIME

To better understand where the model is uncertain, we:

  1. Computed predicted probabilities for all test samples.
  2. Identified the samples whose probabilities lie closest to 0.5 (the least confident predictions).
  3. Applied LIME to these low-confidence samples to see which features influenced the model’s decision.

LIME provides:

  • Feature importance scores for each individual prediction.
  • The actual feature values, so we can compare the model’s reasoning to real-world behavior.

This approach is particularly useful for:

  • Debugging the model.
  • Understanding edge cases in personality prediction.
  • Demonstrating responsible and interpretable AI practices in a portfolio.

Code
# --- Step 1: Compute probabilities for the test set ---
with torch.no_grad():
    inputs = torch.tensor(X_test, dtype=torch.float32)
    logits = model(inputs).squeeze(dim=1)      # shape: (n_samples,)
    prob_pos = torch.sigmoid(logits).numpy()  # probability of class 1 (Extrovert)

# --- Step 2: Compute confidence (distance from 0.5) ---
confidence = np.abs(prob_pos - 0.5)  # smaller -> less confident
low_confidence_indices = np.argsort(confidence)  # indices sorted by ascending confidence

# --- Step 3: Initialize LIME explainer ---
explainer = LimeTabularExplainer(
    training_data=X_train,
    feature_names=feature_names,
    class_names=["Introvert", "Extrovert"],
    mode="classification"
)

# --- Step 4: Loop over lowest-confidence predictions and explain ---
num_samples_to_explain = 5  # choose how many low-confidence examples to inspect

for idx in low_confidence_indices[:num_samples_to_explain]:
    print(f"\n--- Explaining test sample index {idx} ---")
    exp = explainer.explain_instance(
        data_row=X_test[idx],
        predict_fn=predict_fn,
        num_features=10
    )
    exp.show_in_notebook(show_table=True)
    print(exp.as_list())  # print top features influencing this prediction

--- Explaining test sample index 2289 ---
[('match_p <= -0.10', 0.23747182405705047), ('Social_event_attendance_isna > -0.26', -0.11161846625788964), ('alone_flag <= -0.25', 0.09681280073875231), ('Drained_after_socializing <= 0.26', 0.07035683598100916), ('Stage_fear <= 0.34', 0.04635528143338435), ('Post_frequency_isna <= -0.27', 0.04192405207700681), ('Friends_circle_size_isna <= -0.25', -0.03125515522738312), ('Drained_after_socializing_isna <= -0.26', 0.02135035958605656), ('0.25 < Time_spent_Alone <= 0.26', -0.02105197266305121), ('0.44 < social_anxiety_score <= 0.44', -0.020587823612420637)]

--- Explaining test sample index 567 ---
[('match_p <= -0.10', 0.19460217130270382), ('Social_event_attendance_isna <= -0.26', 0.10941245225520066), ('alone_flag <= -0.25', 0.10360793361005463), ('Drained_after_socializing <= 0.26', 0.06456486729162622), ('Stage_fear <= 0.34', 0.049089208049030504), ('Post_frequency_isna <= -0.27', 0.03980039960311537), ('Friends_circle_size_isna > -0.25', 0.035149732645335004), ('Friends_circle_size <= 0.23', -0.0279547594923817), ('Time_spent_Alone_isna <= -0.26', 0.024097516317101046), ('0.26 < Time_spent_Alone <= 0.27', -0.0233715175766218)]

--- Explaining test sample index 2807 ---
[('match_p <= -0.10', 0.24511080839994215), ('alone_flag <= -0.25', 0.09974116996743611), ('Social_event_attendance_isna <= -0.26', 0.09701865635761965), ('Drained_after_socializing > 0.26', -0.057837212266893354), ('Stage_fear > 0.34', -0.05253874261332231), ('Post_frequency_isna > -0.27', -0.0497780372448844), ('Friends_circle_size_isna <= -0.25', -0.028105919337016783), ('Drained_after_socializing_isna <= -0.26', 0.024461929956431816), ('Time_spent_Alone > 0.27', -0.023703035694338678), ('social_engagement_score <= -1.01', -0.022614262486260393)]

--- Explaining test sample index 2370 ---
[('match_p <= -0.10', 0.2056139223424959), ('alone_flag <= -0.25', 0.11310169649987363), ('Social_event_attendance_isna <= -0.26', 0.0983175865429373), ('Drained_after_socializing <= 0.26', 0.06110894385837618), ('Stage_fear <= 0.34', 0.04423926287999881), ('Post_frequency_isna <= -0.27', 0.03485199034944694), ('Friends_circle_size_isna > -0.25', 0.026870395235316948), ('Friends_circle_size <= 0.23', -0.024593406237715564), ('0.26 < Time_spent_Alone <= 0.27', -0.023579504844908325), ('Time_spent_Alone_isna <= -0.26', 0.021667650458118002)]

--- Explaining test sample index 1365 ---
[('match_p <= -0.10', 0.2045185885153542), ('alone_flag <= -0.25', 0.10984582566508244), ('Social_event_attendance_isna <= -0.26', 0.09950834172371004), ('Drained_after_socializing > 0.26', -0.05908224335649049), ('Stage_fear > 0.34', -0.04926667262633966), ('Post_frequency_isna <= -0.27', 0.0380842844689629), ('Friends_circle_size <= 0.23', -0.0316714328482962), ('Post_frequency <= 0.26', -0.03056284845617474), ('Friends_circle_size_isna <= -0.25', -0.029142119303534587), ('social_engagement_score <= -1.01', -0.024662282078095107)]
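Since the same conditions (match_p, alone_flag, Social_event_attendance_isna) recur across samples, a natural follow-up is to aggregate weights across the low-confidence explanations. A minimal standard-library sketch, using the top-three pairs for samples 2289 and 567 copied from the output above (extending it to all samples and all ten pairs is straightforward):

```python
import re
from collections import defaultdict

# Top-three (condition, weight) pairs for two low-confidence samples,
# copied from the LIME output above.
explanations = {
    2289: [('match_p <= -0.10', 0.23747182405705047),
           ('Social_event_attendance_isna > -0.26', -0.11161846625788964),
           ('alone_flag <= -0.25', 0.09681280073875231)],
    567:  [('match_p <= -0.10', 0.19460217130270382),
           ('Social_event_attendance_isna <= -0.26', 0.10941245225520066),
           ('alone_flag <= -0.25', 0.10360793361005463)],
}

def feature_of(condition):
    # The feature name is the first identifier in a LIME condition string;
    # this also handles two-sided conditions like '0.25 < x <= 0.26'.
    return re.search(r'[A-Za-z_][A-Za-z0-9_]*', condition).group()

weights = defaultdict(list)
for pairs in explanations.values():
    for cond, w in pairs:
        weights[feature_of(cond)].append(w)

# Mean absolute weight per feature, so opposite-signed conditions
# (e.g. the two Social_event_attendance_isna splits) don't cancel out.
mean_abs = {f: sum(abs(w) for w in ws) / len(ws) for f, ws in weights.items()}
for f, v in sorted(mean_abs.items(), key=lambda kv: -kv[1]):
    print(f'{f:30s} {v:.3f}')
```

On these two samples, match_p dominates by mean absolute weight, consistent with it topping every individual explanation above.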